@IgWod-IMG
Contributor

With the current design, some values are sunk into a selection region even though they are also used outside that region. This happens because the deserializer sinks the entire basic block containing the conditional branch that forms the header of a selection construct, without accounting for values that are used outside. This manifests as (for example):

<unknown>:0: error: 'spirv.Variable' op failed control flow structurization: it has uses outside of the enclosing selection/loop construct
<unknown>:0: note: see current operation: %0 = "spirv.Variable"()<{storage_class = #spirv.storage_class<Function>}> : () -> !spirv.ptr<vector<4xf32>, Function>

The proposed solution is to split the conditional basic block into two: one block containing just the conditional branch, and another containing the rest of the instructions. With this split, the logic that structures selection regions sinks only the block with the conditional branch and keeps the remaining instructions outside the selection region.
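A minimal sketch of the split on a toy block model, not the deserializer's code. ToyBlock and splitBeforeTerminator are hypothetical names introduced for illustration; the patch itself works on mlir::Block and uses Block::splitBlock plus a newly created spirv.Branch.

```cpp
#include <cassert>
#include <iterator>
#include <list>
#include <string>

// Toy stand-in for a basic block: an ordered list of instruction names.
// (Hypothetical type for illustration; the patch operates on mlir::Block.)
using ToyBlock = std::list<std::string>;

// Splits `header` just before its terminator. On return, `header` holds the
// original non-terminator instructions followed by an unconditional branch,
// and the returned block holds only the original terminator.
inline ToyBlock splitBeforeTerminator(ToyBlock &header) {
  assert(!header.empty() && "a block needs a terminator to be split");
  ToyBlock branchOnly;
  // Move the last instruction (the terminator) into the new block.
  branchOnly.splice(branchOnly.begin(), header, std::prev(header.end()));
  // Reconnect the two halves with an unconditional branch.
  header.push_back("Branch ^branchOnly");
  return branchOnly;
}
```

In this model, only the returned block would be sunk into the selection region, so a Variable defined in the first half stays visible to code after the merge block.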

A SPIR-V test is required, as the problem can only occur during deserialization and cannot be tested with --test-spirv-roundtrip. An MLIR test exhibiting the problematic behaviour would be invalid MLIR in the first place.

This solution is proposed as an alternative to the unfinished PR #123371, which is unlikely to be merged in the foreseeable future, as the author "stepped away from this for a time being". There is also a Discourse thread that tried to solicit feedback on the topic: https://discourse.llvm.org/t/spir-v-uses-outside-the-selection-region/84494

cc @mishaobu

@llvmbot
Member

llvmbot commented Feb 18, 2025

@llvm/pr-subscribers-mlir-spirv

@llvm/pr-subscribers-mlir

Author: Igor Wodiany (IgWod-IMG)

Full diff: https://github.com/llvm/llvm-project/pull/127639.diff

6 Files Affected:

  • (modified) mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp (+42)
  • (modified) mlir/lib/Target/SPIRV/Deserialization/Deserializer.h (+4)
  • (modified) mlir/test/Target/SPIRV/loop.mlir (+2)
  • (modified) mlir/test/Target/SPIRV/selection.mlir (+2)
  • (added) mlir/test/Target/SPIRV/selection.spv (+40)
  • (modified) mlir/test/lit.cfg.py (+1)
diff --git a/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp b/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp
index 04469f1933819..ebf2ecee3207a 100644
--- a/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp
+++ b/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp
@@ -2158,6 +2158,39 @@ LogicalResult spirv::Deserializer::wireUpBlockArgument() {
   return success();
 }
 
+LogicalResult spirv::Deserializer::splitConditionalBlocks() {
+  auto splitBlock = [&](Block *block) {
+    // Do not split loop headers
+    if (auto it = blockMergeInfo.find(block); it != blockMergeInfo.end()) {
+      if (it->second.continueBlock) {
+        return;
+      }
+    }
+
+    if (!block->mightHaveTerminator())
+      return;
+
+    auto terminator = block->getTerminator();
+    assert(terminator != nullptr);
+
+    if (isa<spirv::BranchConditionalOp>(terminator) &&
+        std::distance(block->begin(), block->end()) > 1) {
+      auto newBlock = block->splitBlock(terminator);
+      OpBuilder builder(block, block->end());
+      builder.create<spirv::BranchOp>(block->getParent()->getLoc(), newBlock);
+
+      if (auto it = blockMergeInfo.find(block); it != blockMergeInfo.end()) {
+        auto value = std::move(it->second);
+        blockMergeInfo.erase(it);
+        blockMergeInfo.try_emplace(newBlock, std::move(value));
+      }
+    }
+  };
+  curFunction->walk(splitBlock);
+
+  return success();
+}
+
 LogicalResult spirv::Deserializer::structurizeControlFlow() {
   LLVM_DEBUG({
     logger.startLine()
@@ -2165,6 +2198,15 @@ LogicalResult spirv::Deserializer::structurizeControlFlow() {
     logger.indent();
   });
 
+  LLVM_DEBUG({
+    logger.startLine() << "[cf] split conditional blocks\n";
+    logger.startLine() << "\n";
+  });
+
+  if (failed(splitConditionalBlocks())) {
+    return failure();
+  }
+
   while (!blockMergeInfo.empty()) {
     Block *headerBlock = blockMergeInfo.begin()->first;
     BlockMergeInfo mergeInfo = blockMergeInfo.begin()->second;
diff --git a/mlir/lib/Target/SPIRV/Deserialization/Deserializer.h b/mlir/lib/Target/SPIRV/Deserialization/Deserializer.h
index 264d580c40f09..8dd35aa876726 100644
--- a/mlir/lib/Target/SPIRV/Deserialization/Deserializer.h
+++ b/mlir/lib/Target/SPIRV/Deserialization/Deserializer.h
@@ -246,6 +246,10 @@ class Deserializer {
     return opBuilder.getStringAttr(attrName);
   }
 
+  // Move a conditional branch into a separate basic block to avoid sinking
+  // defs that are required outside a selection region.
+  LogicalResult splitConditionalBlocks();
+
   //===--------------------------------------------------------------------===//
   // Type
   //===--------------------------------------------------------------------===//
diff --git a/mlir/test/Target/SPIRV/loop.mlir b/mlir/test/Target/SPIRV/loop.mlir
index d89600558f56d..dd0d3e1af19dc 100644
--- a/mlir/test/Target/SPIRV/loop.mlir
+++ b/mlir/test/Target/SPIRV/loop.mlir
@@ -267,6 +267,8 @@ spirv.module Physical64 OpenCL requires #spirv.vce<v1.0, [Kernel, Linkage, Addre
       }
 // CHECK-NEXT:       %[[LOAD:.+]] = spirv.Load "Function" %[[VAR]] : i1
       %load = spirv.Load "Function" %var : i1
+// CHECK-NEXT:       spirv.Branch ^[[BB:.+]]
+// CHECK-NEXT:     ^[[BB]]
 // CHECK-NEXT:       spirv.BranchConditional %[[LOAD]], ^[[CONTINUE:.+]](%[[ARG1]] : i64), ^[[LOOP_MERGE:.+]]
       spirv.BranchConditional %load, ^continue(%arg1 : i64), ^loop_merge
 // CHECK-NEXT:     ^[[CONTINUE]](%[[ARG2:.+]]: i64):
diff --git a/mlir/test/Target/SPIRV/selection.mlir b/mlir/test/Target/SPIRV/selection.mlir
index f1d35d74dba15..24abb12998d06 100644
--- a/mlir/test/Target/SPIRV/selection.mlir
+++ b/mlir/test/Target/SPIRV/selection.mlir
@@ -105,6 +105,8 @@ spirv.module Logical GLSL450 requires #spirv.vce<v1.0, [Shader], []> {
     %var = spirv.Variable : !spirv.ptr<i1, Function>
 // CHECK-NEXT:    spirv.Branch ^[[BB:.+]]
 // CHECK-NEXT:  ^[[BB]]:
+// CHECK:    spirv.Branch ^[[BB:.+]]
+// CHECK-NEXT:  ^[[BB]]:
 
 // CHECK-NEXT:    spirv.mlir.selection {
     spirv.mlir.selection {
diff --git a/mlir/test/Target/SPIRV/selection.spv b/mlir/test/Target/SPIRV/selection.spv
new file mode 100644
index 0000000000000..b96e839f5c805
--- /dev/null
+++ b/mlir/test/Target/SPIRV/selection.spv
@@ -0,0 +1,40 @@
+; RUN: %if spirv-tools %{ spirv-as --target-env spv1.0 %s -o - | mlir-translate --deserialize-spirv - -o - | FileCheck %s %}
+; CHECK: spirv.module
+               OpCapability Shader
+          %2 = OpExtInstImport "GLSL.std.450"
+               OpMemoryModel Logical GLSL450
+               OpEntryPoint Fragment %main "main" %colorOut
+               OpExecutionMode %main OriginUpperLeft
+               OpDecorate %colorOut Location 0
+       %void = OpTypeVoid
+          %4 = OpTypeFunction %void
+      %float = OpTypeFloat 32
+    %v4float = OpTypeVector %float 4
+%fun_v4float = OpTypePointer Function %v4float
+    %float_1 = OpConstant %float 1
+    %float_0 = OpConstant %float 0
+         %13 = OpConstantComposite %v4float %float_1 %float_0 %float_0 %float_1
+%out_v4float = OpTypePointer Output %v4float
+   %colorOut = OpVariable %out_v4float Output   
+       %uint = OpTypeInt 32 0
+     %uint_0 = OpConstant %uint 0
+  %out_float = OpTypePointer Output %float
+       %bool = OpTypeBool
+         %25 = OpConstantComposite %v4float %float_1 %float_1 %float_0 %float_1
+       %main = OpFunction %void None %4
+          %6 = OpLabel
+      %color = OpVariable %fun_v4float Function
+               OpStore %color %13
+         %19 = OpAccessChain %out_float %colorOut %uint_0
+         %20 = OpLoad %float %19
+         %22 = OpFOrdEqual %bool %20 %float_1
+               OpSelectionMerge %24 None
+               OpBranchConditional %22 %23 %24
+         %23 = OpLabel
+               OpStore %color %25
+               OpBranch %24
+         %24 = OpLabel
+         %26 = OpLoad %v4float %color
+               OpStore %colorOut %26
+               OpReturn
+               OpFunctionEnd
diff --git a/mlir/test/lit.cfg.py b/mlir/test/lit.cfg.py
index 32b2f8b53d5fa..c447a047eea89 100644
--- a/mlir/test/lit.cfg.py
+++ b/mlir/test/lit.cfg.py
@@ -43,6 +43,7 @@
     ".test",
     ".pdll",
     ".c",
+    ".spv"
 ]
 
 # test_source_root: The root path where tests are located.

assert(terminator != nullptr);

if (isa<spirv::BranchConditionalOp>(terminator) &&
std::distance(block->begin(), block->end()) > 1) {
Member

I think you can do llvm::size(*block) > 1

Contributor Author

IgWod-IMG commented Feb 19, 2025

This doesn't seem to work:

note: candidate template ignored: requirement 'std::is_base_of<std::random_access_iterator_tag, std::bidirectional_iterator_tag>::value' was not satisfied [with R = Block &]

I tried a few other things, but so far std::distance seems to be the cleanest solution.

Member

If it's not random access, then should we check that begin != end && std::next(begin) != end to avoid iterating over the whole thing?

Contributor Author

This is a good point, thanks! I'll update the code.
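A minimal sketch of such a bounded check, written against the standard library only. hasAtLeastTwo is a hypothetical helper name; the thread does not show what the updated patch actually used.

```cpp
#include <iterator>
#include <list>

// Returns true if `range` holds at least two elements. It inspects at most
// the first two positions instead of walking the whole sequence, which is
// what std::distance has to do for non-random-access iterators.
template <typename Range>
bool hasAtLeastTwo(const Range &range) {
  auto first = std::begin(range);
  auto last = std::end(range);
  // Short-circuit: std::next is only evaluated when the range is non-empty.
  return first != last && std::next(first) != last;
}
```

Both conditions must hold, so they are joined with a logical AND; the cost is constant regardless of how many operations the block contains.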

@@ -0,0 +1,40 @@
; RUN: %if spirv-tools %{ spirv-as --target-env spv1.0 %s -o - | mlir-translate --deserialize-spirv - -o - | FileCheck %s %}
; CHECK: spirv.module
Member

Could you add a comment explaining what this test attempts to exercise? Would it be possible to check something more than just that we get a module back?

Contributor Author

The main purpose of this test is to check that deserialization doesn't fail on control flow structurization. I'll add a comment to explain it. I only checked for the module, as the test would fail before anything is generated, but I guess it's an opportunity to also check that the selection region is generated correctly, so I'll add a few more checks.

// CHECK-NEXT: spirv.Branch ^[[BB:.+]]
// CHECK-NEXT: ^[[BB]]:
// CHECK: spirv.Branch ^[[BB:.+]]
// CHECK-NEXT: ^[[BB]]:
Contributor

Thank you for following up on this objective! :)

Just a bit concerned about the added overhead here and in mlir/test/Target/SPIRV/loop.mlir. Is this creating extra unused branches?

Contributor Author

This is a good point, but I don't think it'll be much of an issue outside of running --test-spirv-roundtrip, which is only used for testing. The splitting only happens in the deserializer, so it won't affect any lowerings that use SPIR-V, etc. As for deserialization, it shouldn't be an issue when the input SPIR-V is generated outside MLIR (from GLSL, etc.). I would expect such code to always have something meaningful to split, without introducing superfluous blocks.

Now why do we get extra blocks with the roundtrip? I serialized the MLIR code from the test you commented on and got (trimmed):

          %4 = OpLabel
         %12 = OpVariable %_ptr_Function_bool Function
               OpBranch %13
         %13 = OpLabel
               OpSelectionMerge %16 None
               OpBranchConditional %true %14 %15
         %14 = OpLabel
               OpStore %12 %true
               OpBranch %16
         %15 = OpLabel
               OpStore %12 %false
               OpBranch %16
         %16 = OpLabel

After serializing we get an extra block (%13) that would be unlikely to be present in SPIR-V generated outside MLIR, as OpSelectionMerge would be part of the preceding block. Actually, if you think about it, this is what my patch does: it isolates OpSelectionMerge and OpBranchConditional, so my approach in the deserializer matches how the serializer works. Because the block is already split, deserializing it again splits it further, creating superfluous blocks.

So yes, extra blocks are possible, but I think that would only happen if the input SPIR-V is already split, and I don't think that would happen often with the upstream code. Even if it does happen, it only introduces some direct branches, which I believe are easy to optimise away somewhere down the line: just collapse the blocks together.
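A minimal sketch of what collapsing such blocks could look like, on a toy CFG rather than MLIR. ToyBlock, countPredecessors and collapseDirectBranches are hypothetical names; the sketch assumes the selection has already been structurized, so it appears as a single opaque instruction and merging blocks around it moves nothing into the region.

```cpp
#include <cassert>
#include <cstddef>
#include <string>
#include <vector>

// Toy CFG for illustration only: each block has a body and the indices of its
// successors. A block with exactly one successor ends in a direct branch.
struct ToyBlock {
  std::vector<std::string> body;
  std::vector<std::size_t> successors;
  bool erased;
};

// Counts how many live blocks branch to `target`.
inline std::size_t countPredecessors(const std::vector<ToyBlock> &cfg,
                                     std::size_t target) {
  std::size_t count = 0;
  for (const ToyBlock &block : cfg) {
    if (block.erased)
      continue;
    for (std::size_t succ : block.successors)
      if (succ == target)
        ++count;
  }
  return count;
}

// Repeatedly merges a block into its predecessor when that predecessor ends
// in a direct branch and is the block's only predecessor.
inline void collapseDirectBranches(std::vector<ToyBlock> &cfg) {
  bool changed = true;
  while (changed) {
    changed = false;
    for (std::size_t i = 0; i < cfg.size(); ++i) {
      ToyBlock &block = cfg[i];
      if (block.erased || block.successors.size() != 1)
        continue;
      const std::size_t next = block.successors.front();
      assert(next < cfg.size() && "successor index out of range");
      // Never merge the entry block (index 0) or a self-loop, and only merge
      // when `block` is the sole predecessor of `next`.
      if (next == 0 || next == i || countPredecessors(cfg, next) != 1)
        continue;
      ToyBlock &succ = cfg[next];
      block.body.insert(block.body.end(), succ.body.begin(), succ.body.end());
      block.successors = succ.successors;
      succ.body.clear();
      succ.successors.clear();
      succ.erased = true;
      changed = true;
    }
  }
}
```

Applied to a chain of three blocks connected only by direct branches, this leaves a single block containing all instructions in their original order.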

Hope that makes sense!

Contributor

Really appreciate the rigor, fwiw my entire workflow relies on doing spirv roundtrip.

Basically I am perf sensitive, and want to do spirv instrumentation & optimization passes at MLIR level before going back to spirv for deployment & this currently works very well.

Tested your PR with my use case, and figured I'd throw out a demo of what happens after roundtrip (easier to read as glsl):

Starting point:

float raymarchwater(vec3 camera, vec3 start, vec3 end, float depth) {
    vec3 pos = start;
    vec3 dir = normalize(end - start);
    for(int i=0; i < 64; i++) {
        float height = getwaves(pos.xz, ITERATIONS_RAYMARCH) * depth - depth;
        if(height + 0.01 > pos.y) {
            return distance(pos, camera);
        }
        pos += dir * (pos.y - height);
    }
    return distance(start, camera);
}

Roundtrip from #123371 [My draft]:

highp float raymarchwater(vec3 _201, vec3 _202, vec3 _203, float _204)
{
    vec3 _206 = _202;
    vec3 _207 = normalize(_203 - _202);
    for (int _208 = 0; _208 < 64; _208++)
    {
        vec2 _210 = _206.xz;
        int _211 = 12;
        float _232 = getwaves(_210, _211);
        float _209 = (_232 * _204) - _204;
        if ((_209 + 0.00999999977648258209228515625) > _206.y)
        {
            return distance(_206, _201);
        }
        _206 += (_207 * (_206.y - _209));
    }
    return distance(_202, _201);
}

Roundtrip from this PR:

highp float raymarchwater(vec3 _202, vec3 _203, vec3 _204, float _205)
{
    vec3 _207 = _203;
    vec3 _208 = normalize(_204 - _203);
    int _209 = 0;
    for (;;)
    {
        if (_209 < 64)
        {
            vec2 _211 = _207.xz;
            int _212 = 12;
            float _232 = getwaves(_211, _212);
            float _210 = (_232 * _205) - _205;
            if ((_210 + 0.00999999977648258209228515625) > _207.y)
            {
                return distance(_207, _202);
            }
            _207 += (_208 * (_207.y - _210));
            _209++;
            continue;
        }
        else
        {
            break;
        }
    }
    return distance(_203, _202);
}

To my naive understanding, this still seems concerning, but I'm eager to defer to your or @kuhar's judgement on the matter. Perhaps I need to perform a loop detection pass if your version gets merged(?), or perhaps my perceived worry is actually insignificant.

Contributor Author

IgWod-IMG commented Feb 19, 2025

Thanks for sharing it. I am a bit confused, because I didn't intend to affect loops, so I started looking into a simpler example, and it seems to me that the loop continue block (https://mlir.llvm.org/docs/Dialects/SPIR-V/#loop) gets split into two. That wasn't my intention:

    // Do not split loop headers
    if (auto it = blockMergeInfo.find(block); it != blockMergeInfo.end()) {
      if (it->second.continueBlock) {
        return;
      }
    }

(I incorrectly called it a loop header here)

Let me investigate what's happening first and then I'll come back to you. There is no point in engaging in a deeper discussion when the problem may lie in an incorrect implementation :)

EDIT: Actually, this code may be doing what was intended (as a header can have two outgoing edges) and the comment is correct, so I may need to handle continue blocks as well. Anyway, I need to rethink what I have done here.

@IgWod-IMG
Contributor Author

I have just pushed an updated patch. It addresses autos and makes the test a bit better.

Comment on lines 7 to 29
; CHECK: spirv.module Logical GLSL450 requires #spirv.vce<v1.0, [Shader], []>
OpCapability Shader
%2 = OpExtInstImport "GLSL.std.450"
OpMemoryModel Logical GLSL450
OpEntryPoint Fragment %main "main" %colorOut
OpExecutionMode %main OriginUpperLeft
OpDecorate %colorOut Location 0
%void = OpTypeVoid
%4 = OpTypeFunction %void
%float = OpTypeFloat 32
%v4float = OpTypeVector %float 4
%fun_v4float = OpTypePointer Function %v4float
%float_1 = OpConstant %float 1
%float_0 = OpConstant %float 0
%13 = OpConstantComposite %v4float %float_1 %float_0 %float_0 %float_1
%out_v4float = OpTypePointer Output %v4float
%colorOut = OpVariable %out_v4float Output
%uint = OpTypeInt 32 0
%uint_0 = OpConstant %uint 0
%out_float = OpTypePointer Output %float
%bool = OpTypeBool
%25 = OpConstantComposite %v4float %float_1 %float_1 %float_0 %float_1
; CHECK: spirv.func @main() "None" {
Member

Could you put these checks in a separate block, either before or after the spirv module? I find it hard to read when it's interleaved like this.

@IgWod-IMG
Contributor Author

I have replaced std::distance() and updated the test, so hopefully it's more readable now.

Regarding the bigger issue: I have redesigned the code so that it now only splits conditional blocks that are headers of selection regions. Loops are not touched. Looking at loops (https://mlir.llvm.org/docs/Dialects/SPIR-V/#loop), the entry block will always be unconditional, so splitting would not apply there. The conditional blocks that form the loop header and continue block will be fully sunk anyway, regardless of splitting, so we can skip splitting them.

So, to sum up, this PR only targets selection regions (which so far were the only problem on my side when it comes to sinking), and if we need anything for loops, we can work it out in the future.
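A minimal sketch of that filtering step on toy data, assuming (as the reviewed snippet states) that a non-null continueBlock marks a loop construct. ToyMergeInfo and selectionHeaders are hypothetical names, and block names stand in for Block pointers.

```cpp
#include <map>
#include <string>
#include <vector>

// Hypothetical stand-in for the deserializer's per-header record. A non-null
// continueBlock marks a loop construct; selections leave it null.
struct ToyMergeInfo {
  const char *mergeBlock;
  const char *continueBlock;
};

// Returns the headers that are candidates for splitting: headers of selection
// constructs only.
inline std::vector<std::string>
selectionHeaders(const std::map<std::string, ToyMergeInfo> &blockMergeInfo) {
  std::vector<std::string> result;
  for (const auto &entry : blockMergeInfo) {
    const std::string &header = entry.first;
    const ToyMergeInfo &info = entry.second;
    // Skip loop constructs: their header and continue blocks are sunk whole
    // anyway, so splitting them buys nothing.
    if (info.continueBlock)
      continue;
    result.push_back(header);
  }
  return result;
}
```

Note the use of continue rather than return inside the loop, so that one loop entry does not stop the remaining selection headers from being visited.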

@mishaobu please let me know if the new patch causes any issues. Loops should be untouched now, so they should keep their structure (unless something else is going on). The if statement looked the same in both code snippets, so I hope splitting doesn't cause any issues there. Just to verify, I attempted to run your raymarchwater (GLSL -> SPIR-V -> MLIR -> SPIR-V -> GLSL) and got the same function as you would with your patch. Let me know your thoughts.

There is still an extra block in one of the tests:

  spirv.func @selection_cf() "None" {
    %0 = spirv.Variable : !spirv.ptr<i1, Function>
    spirv.Branch ^bb1
  ^bb1:  // pred: ^bb0
    %true = spirv.Constant true
    spirv.Branch ^bb2
  ^bb2:  // pred: ^bb1
    spirv.mlir.selection {
      spirv.BranchConditional %true, ^bb1, ^bb2
    ^bb1:  // pred: ^bb0
      %true_1 = spirv.Constant true
      spirv.Store "Function" %0, %true_1 : i1
      spirv.Branch ^bb3
    ^bb2:  // pred: ^bb0
      %false = spirv.Constant false
      spirv.Store "Function" %0, %false : i1
      spirv.Branch ^bb3
    ^bb3:  // 2 preds: ^bb1, ^bb2
      spirv.mlir.merge
    }

Compared against:

  spirv.func @selection_cf() "None" {
    %0 = spirv.Variable : !spirv.ptr<i1, Function>
    spirv.Branch ^bb1
  ^bb1:  // pred: ^bb0
    spirv.mlir.selection {
      %true = spirv.Constant true
      spirv.BranchConditional %true, ^bb1, ^bb2
    ^bb1:  // pred: ^bb0
      %true_1 = spirv.Constant true
      spirv.Store "Function" %0, %true_1 : i1
      spirv.Branch ^bb3
    ^bb2:  // pred: ^bb0
      %false = spirv.Constant false
      spirv.Store "Function" %0, %false : i1
      spirv.Branch ^bb3
    ^bb3:  // 2 preds: ^bb1, ^bb2
      spirv.mlir.merge
    }

It happens because the %true constant gets pushed outside the selection region. Arguably this works as intended: the header block gets split. From my perspective this is a rather minor issue (one that will probably get optimised at some later stage anyway), and I think it's a price worth paying for a more robust deserialization. Improving it is something that can be addressed in the future. @kuhar do you have any thoughts on that?

Nevertheless, if we cannot reach a consensus, I'm happy to look for an alternative fix, potentially taking @mishaobu's implementation and getting it into a mergeable state.

}

LogicalResult spirv::Deserializer::splitConditionalBlocks() {
for (auto it = blockMergeInfo.begin(); it != blockMergeInfo.end(); it++) {
Member

Suggested change
- for (auto it = blockMergeInfo.begin(); it != blockMergeInfo.end(); it++) {
+ for (auto it = blockMergeInfo.begin(); it != blockMergeInfo.end(); ++it) {

See https://llvm.org/docs/CodingStandards.html#prefer-preincrement

Comment on lines 2163 to 2164
// Skip processing loop regions. For loop regions continueBlock is non-null.
if (it->second.continueBlock)
Member

you can do auto& [x, y] = *it to give first and second some meaningful names.

@IgWod-IMG
Contributor Author

I have addressed the comments above. Also, I noticed that the loop used return instead of continue, so I corrected that (it was a leftover from the original code). I also split the final condition into two, as I think that falls under: https://llvm.org/docs/CodingStandards.html#use-early-exits-and-continue-to-simplify-code

Member

@kuhar left a comment

LGTM, thanks for the fixes

continue;

Operation *terminator = block->getTerminator();
assert(terminator != nullptr);
Member

Suggested change
- assert(terminator != nullptr);
+ assert(terminator);

@IgWod-IMG
Contributor Author

So unfortunately, I don't think we should merge this change yet. I found some issues with this patch when running more complex shaders. With nested selection regions, the block splitting needs to account for the fact that the split block can be the mergeBlock of another region, and blockMergeInfo needs to be updated accordingly. This should be a relatively easy fix. More worryingly, I started to see some non-deterministic behaviour, so I need some time to dig a bit more.

@IgWod-IMG
Contributor Author

I have pushed an updated patch. blockMergeInfo is now correctly maintained in the splitting function. I also found the cause of the non-deterministic crashes. The non-determinism itself is not fixed (discussed below), but the failures I started to see were due to me modifying the DenseMap key within the loop instead of reinserting the element, which broke the internals of the data structure. Silly me.
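A minimal sketch of the erase-and-reinsert pattern, using std::unordered_map as a stand-in for DenseMap. All names here (ToyBlock, ToyMergeInfo, rekeyHeader, rekeyAll) are hypothetical, and the real patch may organize this differently.

```cpp
#include <cassert>
#include <unordered_map>
#include <utility>
#include <vector>

// Toy stand-ins: blocks are identified by address, and each header maps to
// the merge block of its construct.
struct ToyBlock { int id; };
struct ToyMergeInfo { ToyBlock *mergeBlock; };
using MergeMap = std::unordered_map<ToyBlock *, ToyMergeInfo>;
using RenameMap = std::unordered_map<ToyBlock *, ToyBlock *>;

// Moves the entry keyed by `oldHeader` to `newHeader` by erasing and
// re-inserting it. A key stored in a hash map must not be overwritten in
// place, because the entry would then sit in the bucket of the old key.
inline void rekeyHeader(MergeMap &map, ToyBlock *oldHeader,
                        ToyBlock *newHeader) {
  auto it = map.find(oldHeader);
  assert(it != map.end() && "old header must be present");
  ToyMergeInfo info = std::move(it->second);
  map.erase(it);
  map.emplace(newHeader, std::move(info));
}

// Re-keys every header listed in `renames`. The keys are snapshotted first so
// that the map is never modified while it is being iterated.
inline void rekeyAll(MergeMap &map, const RenameMap &renames) {
  std::vector<ToyBlock *> headers;
  headers.reserve(map.size());
  for (const auto &entry : map)
    headers.push_back(entry.first);
  for (ToyBlock *header : headers) {
    auto renamed = renames.find(header);
    if (renamed != renames.end())
      rekeyHeader(map, header, renamed->second);
  }
}
```

The snapshot costs one extra vector, but it keeps the iteration independent of any rehashing triggered by the insertions.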

@kuhar if you're happy with the updated code then it's ready. Should we wait for @mishaobu to okay it before we merge it?

Now, the "issue" that is still present (although it doesn't cause any crashes with this patch) is that the deserializer is non-deterministic due to its use of a DenseMap. The selection and loop regions are sunk by iterating over the DenseMap, and the order of items within the map may vary from run to run, as the Block pointer is used as a key. This doesn't cause any issues with supported shaders; however, it makes debugging more difficult when things fail. Anyway, it's outside the scope of this PR.
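For illustration only, one common way to make iteration order independent of pointer values is to record insertion order next to the hash lookup; LLVM's llvm::MapVector follows this idea. The sketch below is a standalone, hypothetical InsertionOrderedMap, and nothing in this thread says which approach the deserializer would adopt.

```cpp
#include <cstddef>
#include <unordered_map>
#include <utility>
#include <vector>

// Hypothetical map that iterates in insertion order, regardless of how the
// keys (for example, pointers) happen to hash.
template <typename Key, typename Value>
class InsertionOrderedMap {
public:
  // Inserts (key, value) if the key is absent; returns true on insertion.
  bool insert(const Key &key, Value value) {
    if (index.count(key))
      return false;
    index.emplace(key, entries.size());
    entries.emplace_back(key, std::move(value));
    return true;
  }

  // Returns a pointer to the stored value, or nullptr if the key is absent.
  const Value *lookup(const Key &key) const {
    auto it = index.find(key);
    return it == index.end() ? nullptr : &entries[it->second].second;
  }

  // Entries in the order they were inserted.
  const std::vector<std::pair<Key, Value>> &ordered() const { return entries; }

private:
  std::unordered_map<Key, std::size_t> index;
  std::vector<std::pair<Key, Value>> entries;
};
```

With such a container, two runs that insert headers in the same order visit them in the same order, even if the underlying addresses differ between runs.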

@kuhar
Member

kuhar commented Feb 24, 2025

> @kuhar if you're happy with the updated code then it's ready. Should we wait for @mishaobu to okay it before we merge it?

Seems like they have some stake in this code, so yes.

> Now, the "issue" that is still present (although it doesn't cause any crashes with this patch) is that the deserializer is non-deterministic due to its use of a DenseMap. The selection and loop regions are sunk by iterating over the DenseMap, and the order of items within the map may vary from run to run, as the Block pointer is used as a key. This doesn't cause any issues with supported shaders; however, it makes debugging more difficult when things fail. Anyway, it's outside the scope of this PR.

Could you open an issue for this and add a TODO that links to it in the code?

@IgWod-IMG
Contributor Author

I have created an issue (#128547) and added a TODO.

@mishaobu
Contributor

LGTM; this bridges a key gap in SPIRV <-> MLIR support. Nice work!

@kuhar kuhar merged commit 594919c into llvm:main Feb 24, 2025
11 checks passed
@IgWod-IMG IgWod-IMG deleted the img_split-cond-bb branch February 25, 2025 09:30